{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SRNN on Speech Commands Dataset\n",
    "\n",
    "\n",
    "Please use `fetch_google.sh` to download the Google Speech Commands Dataset and python `process_google.py` to create feature extracted data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-17T17:19:40.265370Z",
     "start_time": "2019-07-17T17:19:39.980681Z"
    }
   },
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import sys\n",
    "import os\n",
    "import numpy as np\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-17T17:19:40.273918Z",
     "start_time": "2019-07-17T17:19:40.267871Z"
    }
   },
   "outputs": [],
   "source": [
    "from edgeml_pytorch.graph.rnn import SRNN2\n",
    "from edgeml_pytorch.trainer.srnnTrainer import SRNNTrainer\n",
    "import edgeml_pytorch.utils as utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-17T17:19:40.340949Z",
     "start_time": "2019-07-17T17:19:40.275892Z"
    }
   },
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "DATA_DIR = './GoogleSpeech/Extracted/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-17T17:19:45.603764Z",
     "start_time": "2019-07-17T17:19:40.343295Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train shape (99, 51088, 32) (51088, 13)\n",
      "Val shape (99, 6798, 32) (6798, 13)\n",
      "Test shape (99, 6835, 32) (6835, 13)\n"
     ]
    }
   ],
   "source": [
    "x_train_, y_train = np.load(DATA_DIR + 'x_train.npy'), np.load(DATA_DIR + 'y_train.npy')\n",
    "x_val_, y_val = np.load(DATA_DIR + 'x_val.npy'), np.load(DATA_DIR + 'y_val.npy')\n",
    "x_test_, y_test = np.load(DATA_DIR + 'x_test.npy'), np.load(DATA_DIR + 'y_test.npy')\n",
    "# Mean-var normalize\n",
    "mean = np.mean(np.reshape(x_train_, [-1, x_train_.shape[-1]]), axis=0)\n",
    "std = np.std(np.reshape(x_train_, [-1, x_train_.shape[-1]]), axis=0)\n",
    "std[std[:] < 0.000001] = 1\n",
    "x_train_ = (x_train_ - mean) / std\n",
    "x_val_ = (x_val_ - mean) / std\n",
    "x_test_ = (x_test_ - mean) / std\n",
    "\n",
    "x_train = np.swapaxes(x_train_, 0, 1)\n",
    "x_val = np.swapaxes(x_val_, 0, 1)\n",
    "x_test = np.swapaxes(x_test_, 0, 1)\n",
    "print(\"Train shape\", x_train.shape, y_train.shape)\n",
    "print(\"Val shape\", x_val.shape, y_val.shape)\n",
    "print(\"Test shape\", x_test.shape, y_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-17T17:19:45.610070Z",
     "start_time": "2019-07-17T17:19:45.606282Z"
    }
   },
   "outputs": [],
   "source": [
    "numTimeSteps = x_train.shape[0]\n",
    "numInput = x_train.shape[-1]\n",
    "numClasses = y_train.shape[1]\n",
    "\n",
    "# Network Parameters\n",
    "brickSize = 11\n",
    "hiddenDim0 = 64\n",
    "hiddenDim1 = 32\n",
    "cellType = 'LSTM'\n",
    "learningRate = 0.01\n",
    "batchSize = 128\n",
    "epochs = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-17T17:19:49.576076Z",
     "start_time": "2019-07-17T17:19:45.612629Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using x-entropy loss\n"
     ]
    }
   ],
   "source": [
    "srnn2 = SRNN2(numInput, numClasses, hiddenDim0, hiddenDim1, cellType).to(device) \n",
    "trainer = SRNNTrainer(srnn2, learningRate, lossType='xentropy', device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-07-17T17:20:28.680161Z",
     "start_time": "2019-07-17T17:19:49.578246Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0 batch 0 loss 4.295151 acc 0.031250\n",
      "Epoch 0 batch 200 loss 1.002617 acc 0.718750\n",
      "Epoch 1 batch 0 loss 0.647069 acc 0.796875\n",
      "Epoch 1 batch 200 loss 0.469229 acc 0.835938\n",
      "Epoch 2 batch 0 loss 0.388671 acc 0.882812\n",
      "Epoch 2 batch 200 loss 0.396696 acc 0.859375\n",
      "Epoch 3 batch 0 loss 0.266433 acc 0.921875\n",
      "Epoch 3 batch 200 loss 0.281694 acc 0.867188\n",
      "Epoch 4 batch 0 loss 0.302240 acc 0.906250\n",
      "Epoch 4 batch 200 loss 0.245797 acc 0.929688\n",
      "Validation accuracy: 0.911003\n",
      "Epoch 5 batch 0 loss 0.202542 acc 0.945312\n",
      "Epoch 5 batch 200 loss 0.192004 acc 0.929688\n",
      "Epoch 6 batch 0 loss 0.256735 acc 0.921875\n",
      "Epoch 6 batch 200 loss 0.279066 acc 0.921875\n",
      "Epoch 7 batch 0 loss 0.228837 acc 0.945312\n",
      "Epoch 7 batch 200 loss 0.222357 acc 0.937500\n",
      "Epoch 8 batch 0 loss 0.164639 acc 0.960938\n",
      "Epoch 8 batch 200 loss 0.160117 acc 0.945312\n",
      "Epoch 9 batch 0 loss 0.173849 acc 0.953125\n",
      "Epoch 9 batch 200 loss 0.201694 acc 0.929688\n",
      "Validation accuracy: 0.912474\n"
     ]
    }
   ],
   "source": [
    "trainer.train(brickSize, batchSize, epochs, x_train, x_val, y_train, y_val, printStep=200, valStep=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
